{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "(tutorial-topi)=\n",
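        "# Introduction to TOPI\n",
        "**Author**: [Ehsan M. Kermani](https://github.com/ehsanmok)\n",
        "\n",
        "This is an introductory tutorial on the TVM Operator Inventory (TOPI). TOPI provides numpy-style generic operations and schedules at a higher level of abstraction than TVM itself. In this tutorial, we will see how TOPI saves us from writing boilerplate code in TVM.\n"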
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tvm\n",
        "import tvm.testing\n",
        "from tvm import te\n",
        "from tvm import topi\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
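        "## Basic example\n",
        "\n",
        "Let's revisit the row-sum operation (equivalent to `B = numpy.sum(A, axis=1)`). To compute the row sum of a two-dimensional TVM tensor `A`, we specify the symbolic operation as well as the schedule, as follows:"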
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "n = te.var(\"n\")\n",
        "m = te.var(\"m\")\n",
        "A = te.placeholder((n, m), name=\"A\")\n",
        "k = te.reduce_axis((0, m), \"k\")\n",
        "B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name=\"B\")\n",
        "s = te.create_schedule(B.op)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To inspect the IR in a human-readable format, we can do the following:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "PrimFunc([A]) attrs={\"from_legacy_te_schedule\": (bool)1, \"global_symbol\": \"main\", \"tir.noalias\": (bool)1} {\n",
              "  allocate B[float32 * n], storage_scope = global\n",
              "  for (i, 0, n) {\n",
              "    B[i] = 0f\n",
              "    for (k, 0, m) {\n",
              "      B[i] = (B[i] + A[((i*stride) + (k*stride))])\n",
              "    }\n",
              "  }\n",
              "}"
            ]
          },
          "execution_count": 4,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "m = tvm.lower(s, [A], simple_mode=True)\n",
        "m[\"main\"]"
      ]
    },
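    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The lowered IR above is a plain two-level loop nest: `B[i]` starts at zero and then accumulates `A[i, k]` over the reduction axis `k`. Below is a sketch of that loop nest written by hand in plain NumPy (not TVM), checked against `numpy.sum`. The 4 × 7 input shape is an arbitrary choice, since `n` and `m` are symbolic in the TVM version.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def row_sum_loops(A):\n",
        "    # Mirror the lowered IR: for each row i, start from 0 and add A[i, k] for every k\n",
        "    n, m = A.shape\n",
        "    B = np.zeros(n, dtype=A.dtype)\n",
        "    for i in range(n):\n",
        "        B[i] = 0.0\n",
        "        for k in range(m):\n",
        "            B[i] = B[i] + A[i, k]\n",
        "    return B\n",
        "\n",
        "# Arbitrary example shape; n and m are symbolic in the TVM version\n",
        "A_np = np.random.uniform(size=(4, 7)).astype('float32')\n",
        "B_np = row_sum_loops(A_np)\n",
        "np.testing.assert_allclose(B_np, np.sum(A_np, axis=1), rtol=1e-5)\n",
        "```"
      ]
    },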
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
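        "However, even for such a common operation we had to define the reduce axis ourselves and write the explicit computation with `te.compute`. Imagine how many details we would need to provide for more complicated operations. Fortunately, we can replace those two lines with a single call to `topi.sum`, just like `numpy.sum`:"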
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "PrimFunc([A]) attrs={\"from_legacy_te_schedule\": (bool)1, \"global_symbol\": \"main\", \"tir.noalias\": (bool)1} {\n",
              "  allocate A_red[float32 * n], storage_scope = global\n",
              "  for (ax0, 0, n) {\n",
              "    A_red[ax0] = 0f\n",
              "    for (k1, 0, m) {\n",
              "      A_red[ax0] = (A_red[ax0] + A[((ax0*stride) + (k1*stride))])\n",
              "    }\n",
              "  }\n",
              "}"
            ]
          },
          "execution_count": 5,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "C = topi.sum(A, axis=1)\n",
        "ts = te.create_schedule(C.op)\n",
        "m = tvm.lower(ts, [A], simple_mode=True)\n",
        "m[\"main\"]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
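        "## Numpy-style operator overloading\n",
        "\n",
        "We can add two tensors with `topi.broadcast_add` as long as their shapes are broadcast-compatible. To make this even shorter, TOPI provides operator overloading for such common operations. For example:"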
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x, y = 100, 10\n",
        "a = te.placeholder((x, y, y), name=\"a\")\n",
        "b = te.placeholder((y, y), name=\"b\")\n",
        "c = a + b  # same as topi.broadcast_add\n",
        "d = a * b  # same as topi.broadcast_mul"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Using the same overloaded syntax, TOPI also handles broadcasting a primitive (`int`, `float`) to a tensor, as in `d - 3.14`."
      ]
    },
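    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "These shape rules follow NumPy-style broadcasting, which TOPI's broadcast operators are modeled on. `b`, with shape (10, 10), is aligned with the trailing two axes of `a`, with shape (100, 10, 10), and is repeated along the leading axis. A Python scalar is applied to every element. The sketch below evaluates the same expressions in NumPy, with NumPy arrays standing in for the TE placeholders.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "x, y = 100, 10\n",
        "a_np = np.random.uniform(size=(x, y, y)).astype('float32')\n",
        "b_np = np.random.uniform(size=(y, y)).astype('float32')\n",
        "\n",
        "# b_np (10, 10) broadcasts against the trailing axes of a_np (100, 10, 10)\n",
        "c_np = a_np + b_np\n",
        "d_np = a_np * b_np\n",
        "# Same result as explicitly repeating b_np along the leading axis\n",
        "assert np.allclose(c_np, a_np + np.broadcast_to(b_np, a_np.shape))\n",
        "\n",
        "# A Python float broadcasts to every element, as in d - 3.14\n",
        "e_np = d_np - 3.14\n",
        "print(c_np.shape, d_np.shape, e_np.shape)  # (100, 10, 10) (100, 10, 10) (100, 10, 10)\n",
        "```"
      ]
    },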
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
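        "## Generic schedules and fusing operations\n",
        "\n",
        "So far we have seen how TOPI saves us from writing explicit computations in the lower-level API. But it doesn't stop there: in the examples above we still created the schedule ourselves, as before. TOPI also provides higher-level scheduling recipes that depend on the target context. For example, for CUDA, we can schedule the following series of operations, which ends with `topi.sum`, using only `topi.cuda.schedule_reduce`:"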
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "@main = primfn(a_1: handle, b_1: handle) -> ()\n",
            "  attr = {\"from_legacy_te_schedule\": True, \"global_symbol\": \"main\", \"tir.noalias\": True}\n",
            "  buffers = {a: Buffer(a_2: Pointer(float32), float32, [10000], []),\n",
            "             b: Buffer(b_2: Pointer(float32), float32, [100], [])}\n",
            "  buffer_map = {a_1: a, b_1: b}\n",
            "  preflattened_buffer_map = {a_1: a_3: Buffer(a_2, float32, [100, 10, 10], []), b_1: b_3: Buffer(b_2, float32, [10, 10], [])} {\n",
            "  allocate(T_divide_red: Pointer(global float32), float32, [1]), storage_scope = global;\n",
            "  attr [IterVar(threadIdx.x: int32, [0:1024], \"ThreadIndex\", \"threadIdx.x\")] \"thread_extent\" = 1024;\n",
            "  allocate(T_divide_red.rf: Pointer(local float32), float32, [1]), storage_scope = local;\n",
            "  allocate(reduce_temp0: Pointer(local float32), float32, [1]), storage_scope = local {\n",
            "    T_divide_red.rf_1: Buffer(T_divide_red.rf, float32, [1], [], scope=\"local\", align=4)[0] = 0f32\n",
            "    for (k0.k1.fused.k2.fused.outer: int32, 0, 10) {\n",
            "      if @tir.likely((((((k0.k1.fused.k2.fused.outer*64) + floordiv(threadIdx.x, 16)) < 625) && (((k0.k1.fused.k2.fused.outer*64) + floordiv(threadIdx.x, 16)) < 625)) && (((k0.k1.fused.k2.fused.outer*64) + floordiv(threadIdx.x, 16)) < 625)), dtype=bool) {\n",
            "        T_divide_red.rf_1[0] = (T_divide_red.rf_1[0] + (((a[((k0.k1.fused.k2.fused.outer*1024) + threadIdx.x)] + b[floormod(((k0.k1.fused.k2.fused.outer*1024) + threadIdx.x), 100)]) + (a[((k0.k1.fused.k2.fused.outer*1024) + threadIdx.x)]*b[floormod(((k0.k1.fused.k2.fused.outer*1024) + threadIdx.x), 100)]))*0.5f32))\n",
            "      }\n",
            "    }\n",
            "    attr [meta[tir.CommReducer][0]] \"reduce_scope\" = @tir.reinterpret(0u64, dtype=handle);\n",
            "    @tir.tvm_thread_allreduce(1u32, T_divide_red.rf_1[0], True, reduce_temp0_1: Buffer(reduce_temp0, float32, [1], [], scope=\"local\")[0], threadIdx.x, dtype=handle)\n",
            "    if (threadIdx.x == 0) {\n",
            "      T_divide_red_1: Buffer(T_divide_red, float32, [1], [], align=4)[0] = reduce_temp0_1[0]\n",
            "    }\n",
            "  }\n",
            "}\n",
            "\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/media/pc/data/4tb/lxw/anaconda3/envs/tvm-mxnet/lib/python3.10/site-packages/tvm/target/target.py:347: UserWarning: Try specifying cuda arch by adding 'arch=sm_xx' to your target.\n",
            "  warnings.warn(\"Try specifying cuda arch by adding 'arch=sm_xx' to your target.\")\n"
          ]
        }
      ],
      "source": [
        "e = topi.elemwise_sum([c, d])\n",
        "f = e / 2.0\n",
        "g = topi.sum(f)\n",
        "with tvm.target.cuda():\n",
        "    sg = topi.cuda.schedule_reduce(g)\n",
        "    print(tvm.lower(sg, [a, b], simple_mode=True))"
      ]
    },
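    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here is one way to read the printed IR. `topi.cuda.schedule_reduce` fused the three axes of `f` into one flat axis of 100 × 10 × 10 = 10000 elements. It split that axis across 1024 threads (`threadIdx.x`). Each thread runs ceil(10000 / 1024) = 10 outer iterations and accumulates a partial sum, and `tvm_thread_allreduce` then combines the partial sums. The check below is plain-Python arithmetic based on that reading of the output above, not a TVM API. It confirms that the bounds guard in the IR is just an in-range test on the flat index.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "# Shapes from the example: a is (100, 10, 10) and topi.sum(f) reduces over every axis\n",
        "total = 100 * 10 * 10   # size of the fused reduction axis\n",
        "threads = 1024          # threadIdx.x extent in the printed IR\n",
        "outer = math.ceil(total / threads)\n",
        "print(total, outer)     # 10000 10\n",
        "\n",
        "# The IR guard (outer*64 + tx // 16) < 625 is the in-bounds test outer*1024 + tx < 10000\n",
        "# with both sides divided by 16 (1024 = 64 * 16 and 10000 = 625 * 16)\n",
        "for o in range(outer):\n",
        "    for tx in range(threads):\n",
        "        flat = o * threads + tx\n",
        "        assert ((o * 64 + tx // 16) < 625) == (flat < total)\n",
        "```"
      ]
    },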
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As you can see, the scheduled stages of the computation have been accumulated, and we can examine them with:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[stage(a, placeholder(a, 0x55dde093b7f0)), stage(b, placeholder(b, 0x55dddfadec40)), stage(T_add, compute(T_add, body=[(a[ax0, ax1, ax2] + b[ax1, ax2])], axis=[iter_var(ax0, range(min=0, ext=100)), iter_var(ax1, range(min=0, ext=10)), iter_var(ax2, range(min=0, ext=10))], reduce_axis=[], tag=broadcast, attrs={})), stage(T_multiply, compute(T_multiply, body=[(a[ax0, ax1, ax2]*b[ax1, ax2])], axis=[iter_var(ax0, range(min=0, ext=100)), iter_var(ax1, range(min=0, ext=10)), iter_var(ax2, range(min=0, ext=10))], reduce_axis=[], tag=broadcast, attrs={})), stage(T_elemwise_sum, compute(T_elemwise_sum, body=[(T_add[ax0, ax1, ax2] + T_multiply[ax0, ax1, ax2])], axis=[iter_var(ax0, range(min=0, ext=100)), iter_var(ax1, range(min=0, ext=10)), iter_var(ax2, range(min=0, ext=10))], reduce_axis=[], tag=elemwise, attrs={})), stage(T_divide, compute(T_divide, body=[(T_elemwise_sum[ax0, ax1, ax2]/2f)], axis=[iter_var(ax0, range(min=0, ext=100)), iter_var(ax1, range(min=0, ext=10)), iter_var(ax2, range(min=0, ext=10))], reduce_axis=[], tag=elemwise, attrs={})), stage(T_divide_red.rf, compute(T_divide_red.rf, body=[reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f]), source=[T_divide[floordiv(floordiv((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10), 10), floormod(floordiv((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10), 10), floormod((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10)]], init=[], axis=[iter_var(k0.k1.fused.k2.fused.outer, range(min=0, ext=10))], where=tir.likely((((floordiv(floordiv((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10), 10) < 100) && (floordiv((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)), 10) < 1000)) && ((k0.k1.fused.k2.fused.inner + (k0.k1.fused.k2.fused.outer*1024)) < 10000))), value_index=0)], axis=[iter_var(k0.k1.fused.k2.fused.inner, range(min=0, ext=1024))], reduce_axis=[iter_var(k0.k1.fused.k2.fused.outer, range(min=0, ext=10))], tag=, attrs={})), stage(T_divide_red, compute(T_divide_red.repl, body=[reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f]), source=[T_divide_red.rf[k0.k1.fused.k2.fused.inner.v]], init=[], axis=[iter_var(k0.k1.fused.k2.fused.inner.v, range(min=0, ext=1024))], where=(bool)1, value_index=0)], axis=[], reduce_axis=[iter_var(k0.k1.fused.k2.fused.inner.v, range(min=0, ext=1024))], tag=, attrs={}))]\n"
          ]
        }
      ],
      "source": [
        "print(sg.stages)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can test correctness by comparing against the `numpy` result, as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "func = tvm.build(sg, [a, b, g], \"cuda\")\n",
        "dev = tvm.cuda(0)\n",
        "a_np = np.random.uniform(size=(x, y, y)).astype(a.dtype)\n",
        "b_np = np.random.uniform(size=(y, y)).astype(b.dtype)\n",
        "g_np = np.sum(np.add(a_np + b_np, a_np * b_np) / 2.0)\n",
        "a_nd = tvm.nd.array(a_np, dev)\n",
        "b_nd = tvm.nd.array(b_np, dev)\n",
        "g_nd = tvm.nd.array(np.zeros(g_np.shape, dtype=g_np.dtype), dev)\n",
        "func(a_nd, b_nd, g_nd)\n",
        "tvm.testing.assert_allclose(g_nd.numpy(), g_np, rtol=1e-5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "TOPI also provides common neural network operations, such as _softmax_, with optimized schedules:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "@main = primfn(tarray_1: handle) -> ()\n",
            "  attr = {\"from_legacy_te_schedule\": True, \"global_symbol\": \"main\", \"tir.noalias\": True}\n",
            "  buffers = {tarray: Buffer(tarray_2: Pointer(float32), float32, [262144], [])}\n",
            "  buffer_map = {tarray_1: tarray}\n",
            "  preflattened_buffer_map = {tarray_1: tarray_3: Buffer(tarray_2, float32, [512, 512], [])} {\n",
            "  allocate(T_softmax_norm: Pointer(global float32x4), float32x4, [65536]), storage_scope = global;\n",
            "  attr [IterVar(blockIdx.x: int32, (nullptr), \"ThreadIndex\", \"blockIdx.x\")] \"thread_extent\" = 512;\n",
            "  allocate(normal_reduce_temp0: Pointer(local float32), float32, [1]), storage_scope = local;\n",
            "  allocate(reduce_temp0: Pointer(local float32), float32, [1]), storage_scope = local;\n",
            "  allocate(T_softmax_exp: Pointer(warp float32), float32, [512]), storage_scope = warp;\n",
            "  allocate(normal_reduce_temp0_1: Pointer(local float32), float32, [1]), storage_scope = local;\n",
            "  allocate(reduce_temp0_1: Pointer(local float32), float32, [1]), storage_scope = local {\n",
            "    attr [IterVar(threadIdx.x: int32, [0:32], \"ThreadIndex\", \"threadIdx.x\")] \"thread_extent\" = 32 {\n",
            "      normal_reduce_temp0_2: Buffer(normal_reduce_temp0, float32, [1], [], scope=\"local\")[0] = -3.40282e+38f32\n",
            "      for (k.inner: int32, 0, 16) {\n",
            "        normal_reduce_temp0_2[0] = max(normal_reduce_temp0_2[0], tarray[(((blockIdx.x*512) + (threadIdx.x*16)) + k.inner)])\n",
            "      }\n",
            "      attr [meta[tir.CommReducer][0]] \"reduce_scope\" = @tir.reinterpret(0u64, dtype=handle);\n",
            "      @tir.tvm_thread_allreduce(1u32, normal_reduce_temp0_2[0], True, reduce_temp0_2: Buffer(reduce_temp0, float32, [1], [], scope=\"local\")[0], threadIdx.x, dtype=handle)\n",
            "      for (i1.inner.outer: int32, 0, 4) {\n",
            "        let cse_var_1: int32 = (i1.inner.outer*4)\n",
            "        T_softmax_exp_1: Buffer(T_softmax_exp, float32, [512], [], scope=\"warp\")[ramp(((threadIdx.x*16) + cse_var_1), 1, 4)] = @tir.exp((tarray[ramp((((blockIdx.x*512) + (threadIdx.x*16)) + cse_var_1), 1, 4)] - broadcast(reduce_temp0_3: Buffer(reduce_temp0, float32, [1], [], scope=\"local\", align=4)[0], 4)), dtype=float32x4)\n",
            "      }\n",
            "    }\n",
            "    attr [IterVar(threadIdx.x, [0:32], \"ThreadIndex\", \"threadIdx.x\")] \"thread_extent\" = 32 {\n",
            "      normal_reduce_temp0_3: Buffer(normal_reduce_temp0_1, float32, [1], [], scope=\"local\")[0] = 0f32\n",
            "      for (k.inner_1: int32, 0, 16) {\n",
            "        normal_reduce_temp0_3[0] = (normal_reduce_temp0_3[0] + T_softmax_exp_1[((threadIdx.x*16) + k.inner_1)])\n",
            "      }\n",
            "      attr [meta[tir.CommReducer][1]] \"reduce_scope\" = @tir.reinterpret(0u64, dtype=handle);\n",
            "      @tir.tvm_thread_allreduce(1u32, normal_reduce_temp0_3[0], True, reduce_temp0_4: Buffer(reduce_temp0_1, float32, [1], [], scope=\"local\")[0], threadIdx.x, dtype=handle)\n",
            "      for (i1.inner.outer_1: int32, 0, 4) {\n",
            "        T_softmax_norm_1: Buffer(T_softmax_norm, float32x4, [65536], [])[(((blockIdx.x*128) + (threadIdx.x*4)) + i1.inner.outer_1)] = (T_softmax_exp_1[ramp(((threadIdx.x*16) + (i1.inner.outer_1*4)), 1, 4)] / broadcast(reduce_temp0_5: Buffer(reduce_temp0_1, float32, [1], [], scope=\"local\", align=4)[0], 4))\n",
            "      }\n",
            "    }\n",
            "  }\n",
            "}\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "tarray = te.placeholder((512, 512), name=\"tarray\")\n",
        "softmax_topi = topi.nn.softmax(tarray)\n",
        "with tvm.target.Target(\"cuda\"):\n",
        "    sst = topi.cuda.schedule_softmax(softmax_topi)\n",
        "    print(tvm.lower(sst, [tarray], simple_mode=True))"
      ]
    },
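    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The printed IR follows the usual numerically stable softmax, with one 512-element row per `blockIdx.x`. It first takes a per-row max reduction, then computes `exp` of the shifted values (`T_softmax_exp`). Finally it takes a per-row sum and divides (`T_softmax_norm`). Below is a NumPy reference sketch of the same three passes. It assumes softmax over the last axis, which is the default `axis=-1` of `topi.nn.softmax`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def softmax_reference(x):\n",
        "    # Pass 1: per-row maximum, subtracted for numerical stability\n",
        "    row_max = np.max(x, axis=1, keepdims=True)\n",
        "    # Pass 2: exponentiate the shifted values (T_softmax_exp in the IR)\n",
        "    e = np.exp(x - row_max)\n",
        "    # Pass 3: per-row sum, then normalize (T_softmax_norm in the IR)\n",
        "    return e / np.sum(e, axis=1, keepdims=True)\n",
        "\n",
        "x_np = np.random.uniform(size=(512, 512)).astype('float32')\n",
        "out = softmax_reference(x_np)\n",
        "# Every row of a softmax sums to 1\n",
        "np.testing.assert_allclose(out.sum(axis=1), np.ones(512), rtol=1e-5)\n",
        "```"
      ]
    },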
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
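        "## Fusing convolutions\n",
        "\n",
        "We can fuse `topi.nn.conv2d` and `topi.nn.relu` together.\n",
        "\n",
        "```{admonition} Note\n",
        ":class: alert alert-info\n",
        "\n",
        "TOPI functions are all generic functions. They have different implementations for different backends to optimize for performance. For each backend, they must be called within a target scope for both the compute declaration and the schedule. TVM chooses the right function to call based on the target information.\n",
        "```"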
      ]
    },
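    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, here is what a fused conv2d + ReLU in NCHW layout computes, sketched in NumPy. This is not TOPI's implementation. The defaults of stride 1, padding 2 and dilation 1 are assumptions taken from the workload tuple in the log of the next cell, and the small 8 × 8 input only keeps the example fast. With padding 2 and a 5 × 5 kernel, the spatial size is preserved. That matches the 10 × 224 × 224 = 501760-element `compute` buffer in the IR.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def conv2d_nchw_relu(data, kernel, stride=1, padding=2):\n",
        "    # data: (N, C, H, W); kernel: (K, C, R, S); dilation is assumed to be 1\n",
        "    n, c, h, w = data.shape\n",
        "    k, c2, r, s = kernel.shape\n",
        "    assert c == c2, 'input channels must match'\n",
        "    # Zero-pad the two spatial axes\n",
        "    padded = np.pad(data, ((0, 0), (0, 0), (padding, padding), (padding, padding)))\n",
        "    out_h = (h + 2 * padding - r) // stride + 1\n",
        "    out_w = (w + 2 * padding - s) // stride + 1\n",
        "    out = np.zeros((n, k, out_h, out_w), dtype=data.dtype)\n",
        "    for oy in range(out_h):\n",
        "        for ox in range(out_w):\n",
        "            # (N, C, R, S) input window for this output pixel\n",
        "            window = padded[:, :, oy * stride:oy * stride + r, ox * stride:ox * stride + s]\n",
        "            # Contract over C, R, S for every output channel K\n",
        "            out[:, :, oy, ox] = np.einsum('ncrs,kcrs->nk', window, kernel)\n",
        "    # ReLU applied elementwise to the convolution output\n",
        "    return np.maximum(out, 0)\n",
        "\n",
        "data_np = np.random.uniform(-1, 1, size=(1, 3, 8, 8)).astype('float32')\n",
        "kernel_np = np.random.uniform(-1, 1, size=(10, 3, 5, 5)).astype('float32')\n",
        "out_np = conv2d_nchw_relu(data_np, kernel_np)\n",
        "print(out_np.shape)  # (1, 10, 8, 8)\n",
        "```"
      ]
    },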
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Cannot find config for target=cuda -keys=cuda,gpu -arch=sm_75 -max_num_threads=1024 -thread_warp_size=32, workload=('conv2d_nchw.cuda', ('TENSOR', (1, 3, 224, 224), 'float32'), ('TENSOR', (10, 3, 5, 5), 'float32'), 1, 2, 1). A fallback configuration is used, which may bring great performance regression.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "@main = primfn(placeholder_2: handle, placeholder_3: handle) -> ()\n",
            "  attr = {\"from_legacy_te_schedule\": True, \"global_symbol\": \"main\", \"tir.noalias\": True}\n",
            "  buffers = {placeholder: Buffer(placeholder_4: Pointer(float32), float32, [150528], []),\n",
            "             placeholder_1: Buffer(placeholder_5: Pointer(float32), float32, [750], [])}\n",
            "  buffer_map = {placeholder_2: placeholder, placeholder_3: placeholder_1}\n",
            "  preflattened_buffer_map = {placeholder_2: placeholder_6: Buffer(placeholder_4, float32, [1, 3, 224, 224], []), placeholder_3: placeholder_7: Buffer(placeholder_5, float32, [10, 3, 5, 5], [])} {\n",
            "  allocate(compute: Pointer(global float32), float32, [501760]), storage_scope = global;\n",
            "  attr [IterVar(blockIdx.z: int32, (nullptr), \"ThreadIndex\", \"blockIdx.z\")] \"thread_extent\" = 5;\n",
            "  allocate(conv2d_nchw: Pointer(local float32), float32, [14]), storage_scope = local;\n",
            "  allocate(pad_temp.shared: Pointer(shared float32), float32, [112]), storage_scope = shared;\n",
            "  allocate(placeholder.shared: Pointer(shared float32), float32, [2]), storage_scope = shared;\n",
            "  attr [IterVar(blockIdx.y: int32, (nullptr), \"ThreadIndex\", \"blockIdx.y\")] \"thread_extent\" = 224;\n",
            "  attr [IterVar(blockIdx.x: int32, (nullptr), \"ThreadIndex\", \"blockIdx.x\")] \"thread_extent\" = 2;\n",
            "  attr [IterVar(threadIdx.z: int32, (nullptr), \"ThreadIndex\", \"threadIdx.z\")] \"thread_extent\" = 1;\n",
            "  attr [IterVar(threadIdx.y: int32, (nullptr), \"ThreadIndex\", \"threadIdx.y\")] \"thread_extent\" = 1;\n",
            "  attr [IterVar(threadIdx.x: int32, (nullptr), \"ThreadIndex\", \"threadIdx.x\")] \"thread_extent\" = 16 {\n",
            "    conv2d_nchw_1: Buffer(conv2d_nchw, float32, [4], [], scope=\"local\", align=8)[0] = 0f32\n",
            "    conv2d_nchw_1[2] = 0f32\n",
            "    conv2d_nchw_1[4] = 0f32\n",
            "    conv2d_nchw_1[6] = 0f32\n",
            "    conv2d_nchw_1[8] = 0f32\n",
            "    conv2d_nchw_1[10] = 0f32\n",
            "    conv2d_nchw_1[12] = 0f32\n",
            "    conv2d_nchw_1[1] = 0f32\n",
            "    conv2d_nchw_1[3] = 0f32\n",
            "    conv2d_nchw_1[5] = 0f32\n",
            "    conv2d_nchw_1[7] = 0f32\n",
            "    conv2d_nchw_1[9] = 0f32\n",
            "    conv2d_nchw_1[11] = 0f32\n",
            "    conv2d_nchw_1[13] = 0f32\n",
            "    for (rc.outer: int32, 0, 3) {\n",
            "      for (ry.outer: int32, 0, 5) {\n",
            "        attr [IterVar(threadIdx.z_1: int32, (nullptr), \"ThreadIndex\", \"threadIdx.z\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.y_1: int32, (nullptr), \"ThreadIndex\", \"threadIdx.y\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.x_1: int32, (nullptr), \"ThreadIndex\", \"threadIdx.x\")] \"thread_extent\" = 16 {\n",
            "          pad_temp.shared_1: Buffer(pad_temp.shared, float32, [112], [], scope=\"shared\")[(threadIdx.x_1*7)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (1 <= ((blockIdx.x*56) + floordiv((threadIdx.x_1*7), 2)))), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 450)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 1)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (1 <= ((blockIdx.x*56) + floordiv(((threadIdx.x_1*7) + 1), 2)))), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 449)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 448)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 447)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 5)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 6)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32)\n",
            "        }\n",
            "        attr [IterVar(threadIdx.z_2: int32, (nullptr), \"ThreadIndex\", \"threadIdx.z\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.y_2: int32, (nullptr), \"ThreadIndex\", \"threadIdx.y\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.x_2: int32, (nullptr), \"ThreadIndex\", \"threadIdx.x\")] \"thread_extent\" = 16;\n",
            "        if @tir.likely((threadIdx.x_2 < 2), dtype=bool) {\n",
            "          placeholder.shared_1: Buffer(placeholder.shared, float32, [2], [], scope=\"shared\", align=8)[threadIdx.x_2] = placeholder_1[((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5))]\n",
            "        }\n",
            "        conv2d_nchw_1[0] = (conv2d_nchw_1[0] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[2] = (conv2d_nchw_1[2] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[4] = (conv2d_nchw_1[4] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[6] = (conv2d_nchw_1[6] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[8] = (conv2d_nchw_1[8] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[10] = (conv2d_nchw_1[10] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[12] = (conv2d_nchw_1[12] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[1] = (conv2d_nchw_1[1] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[3] = (conv2d_nchw_1[3] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[5] = (conv2d_nchw_1[5] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[7] = (conv2d_nchw_1[7] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[9] = (conv2d_nchw_1[9] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[11] = (conv2d_nchw_1[11] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[13] = (conv2d_nchw_1[13] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[1]))\n",
            "        attr [IterVar(threadIdx.z_1, (nullptr), \"ThreadIndex\", \"threadIdx.z\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.y_1, (nullptr), \"ThreadIndex\", \"threadIdx.y\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.x_1, (nullptr), \"ThreadIndex\", \"threadIdx.x\")] \"thread_extent\" = 16 {\n",
            "          pad_temp.shared_1[(threadIdx.x_1*7)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (1 <= ((blockIdx.x*56) + floordiv(((threadIdx.x_1*7) + 1), 2)))), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 449)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 1)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 448)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 447)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 5)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 6)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 443)], 0f32, dtype=float32)\n",
            "        }\n",
            "        attr [IterVar(threadIdx.z_2, (nullptr), \"ThreadIndex\", \"threadIdx.z\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.y_2, (nullptr), \"ThreadIndex\", \"threadIdx.y\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.x_2, (nullptr), \"ThreadIndex\", \"threadIdx.x\")] \"thread_extent\" = 16;\n",
            "        if @tir.likely((threadIdx.x_2 < 2), dtype=bool) {\n",
            "          placeholder.shared_1[threadIdx.x_2] = placeholder_1[(((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5)) + 1)]\n",
            "        }\n",
            "        conv2d_nchw_1[0] = (conv2d_nchw_1[0] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[2] = (conv2d_nchw_1[2] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[4] = (conv2d_nchw_1[4] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[6] = (conv2d_nchw_1[6] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[8] = (conv2d_nchw_1[8] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[10] = (conv2d_nchw_1[10] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[12] = (conv2d_nchw_1[12] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[1] = (conv2d_nchw_1[1] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[3] = (conv2d_nchw_1[3] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[5] = (conv2d_nchw_1[5] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[7] = (conv2d_nchw_1[7] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[9] = (conv2d_nchw_1[9] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[11] = (conv2d_nchw_1[11] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[13] = (conv2d_nchw_1[13] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[1]))\n",
            "        attr [IterVar(threadIdx.z_1, (nullptr), \"ThreadIndex\", \"threadIdx.z\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.y_1, (nullptr), \"ThreadIndex\", \"threadIdx.y\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.x_1, (nullptr), \"ThreadIndex\", \"threadIdx.x\")] \"thread_extent\" = 16 {\n",
            "          pad_temp.shared_1[(threadIdx.x_1*7)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 448)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 1)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 447)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 5)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 443)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 6)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 442)], 0f32, dtype=float32)\n",
            "        }\n",
            "        attr [IterVar(threadIdx.z_2, (nullptr), \"ThreadIndex\", \"threadIdx.z\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.y_2, (nullptr), \"ThreadIndex\", \"threadIdx.y\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.x_2, (nullptr), \"ThreadIndex\", \"threadIdx.x\")] \"thread_extent\" = 16;\n",
            "        if @tir.likely((threadIdx.x_2 < 2), dtype=bool) {\n",
            "          placeholder.shared_1[threadIdx.x_2] = placeholder_1[(((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5)) + 2)]\n",
            "        }\n",
            "        conv2d_nchw_1[0] = (conv2d_nchw_1[0] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[2] = (conv2d_nchw_1[2] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[4] = (conv2d_nchw_1[4] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[6] = (conv2d_nchw_1[6] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[8] = (conv2d_nchw_1[8] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[10] = (conv2d_nchw_1[10] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[12] = (conv2d_nchw_1[12] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[1] = (conv2d_nchw_1[1] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[3] = (conv2d_nchw_1[3] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[5] = (conv2d_nchw_1[5] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[7] = (conv2d_nchw_1[7] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[9] = (conv2d_nchw_1[9] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[11] = (conv2d_nchw_1[11] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[13] = (conv2d_nchw_1[13] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[1]))\n",
            "        attr [IterVar(threadIdx.z_1, (nullptr), \"ThreadIndex\", \"threadIdx.z\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.y_1, (nullptr), \"ThreadIndex\", \"threadIdx.y\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.x_1, (nullptr), \"ThreadIndex\", \"threadIdx.x\")] \"thread_extent\" = 16 {\n",
            "          pad_temp.shared_1[(threadIdx.x_1*7)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 447)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 1)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 443)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 5)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 442)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 6)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (((blockIdx.x*56) + floordiv(((threadIdx.x_1*7) + 9), 2)) < 113)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 441)], 0f32, dtype=float32)\n",
            "        }\n",
            "        attr [IterVar(threadIdx.z_2, (nullptr), \"ThreadIndex\", \"threadIdx.z\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.y_2, (nullptr), \"ThreadIndex\", \"threadIdx.y\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.x_2, (nullptr), \"ThreadIndex\", \"threadIdx.x\")] \"thread_extent\" = 16;\n",
            "        if @tir.likely((threadIdx.x_2 < 2), dtype=bool) {\n",
            "          placeholder.shared_1[threadIdx.x_2] = placeholder_1[(((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5)) + 3)]\n",
            "        }\n",
            "        conv2d_nchw_1[0] = (conv2d_nchw_1[0] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[2] = (conv2d_nchw_1[2] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[4] = (conv2d_nchw_1[4] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[6] = (conv2d_nchw_1[6] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[8] = (conv2d_nchw_1[8] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[10] = (conv2d_nchw_1[10] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[12] = (conv2d_nchw_1[12] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[1] = (conv2d_nchw_1[1] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[3] = (conv2d_nchw_1[3] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[5] = (conv2d_nchw_1[5] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[7] = (conv2d_nchw_1[7] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[9] = (conv2d_nchw_1[9] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[11] = (conv2d_nchw_1[11] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[13] = (conv2d_nchw_1[13] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[1]))\n",
            "        attr [IterVar(threadIdx.z_1, (nullptr), \"ThreadIndex\", \"threadIdx.z\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.y_1, (nullptr), \"ThreadIndex\", \"threadIdx.y\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.x_1, (nullptr), \"ThreadIndex\", \"threadIdx.x\")] \"thread_extent\" = 16 {\n",
            "          pad_temp.shared_1[(threadIdx.x_1*7)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 446)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 1)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 445)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 2)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 444)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 3)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 443)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 4)] = @tir.if_then_else(((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 442)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 5)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (((blockIdx.x*56) + floordiv(((threadIdx.x_1*7) + 9), 2)) < 113)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 441)], 0f32, dtype=float32)\n",
            "          pad_temp.shared_1[((threadIdx.x_1*7) + 6)] = @tir.if_then_else((((2 <= (blockIdx.y + ry.outer)) && ((blockIdx.y + ry.outer) < 226)) && (((blockIdx.x*56) + floordiv((threadIdx.x_1*7), 2)) < 108)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (ry.outer*224)) + (blockIdx.x*112)) + (threadIdx.x_1*7)) - 440)], 0f32, dtype=float32)\n",
            "        }\n",
            "        attr [IterVar(threadIdx.z_2, (nullptr), \"ThreadIndex\", \"threadIdx.z\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.y_2, (nullptr), \"ThreadIndex\", \"threadIdx.y\")] \"thread_extent\" = 1;\n",
            "        attr [IterVar(threadIdx.x_2, (nullptr), \"ThreadIndex\", \"threadIdx.x\")] \"thread_extent\" = 16;\n",
            "        if @tir.likely((threadIdx.x_2 < 2), dtype=bool) {\n",
            "          placeholder.shared_1[threadIdx.x_2] = placeholder_1[(((((blockIdx.z*150) + (threadIdx.x_2*75)) + (rc.outer*25)) + (ry.outer*5)) + 4)]\n",
            "        }\n",
            "        conv2d_nchw_1[0] = (conv2d_nchw_1[0] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[2] = (conv2d_nchw_1[2] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[4] = (conv2d_nchw_1[4] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[6] = (conv2d_nchw_1[6] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[8] = (conv2d_nchw_1[8] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[10] = (conv2d_nchw_1[10] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[12] = (conv2d_nchw_1[12] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[0]))\n",
            "        conv2d_nchw_1[1] = (conv2d_nchw_1[1] + (pad_temp.shared_1[threadIdx.x]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[3] = (conv2d_nchw_1[3] + (pad_temp.shared_1[(threadIdx.x + 16)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[5] = (conv2d_nchw_1[5] + (pad_temp.shared_1[(threadIdx.x + 32)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[7] = (conv2d_nchw_1[7] + (pad_temp.shared_1[(threadIdx.x + 48)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[9] = (conv2d_nchw_1[9] + (pad_temp.shared_1[(threadIdx.x + 64)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[11] = (conv2d_nchw_1[11] + (pad_temp.shared_1[(threadIdx.x + 80)]*placeholder.shared_1[1]))\n",
            "        conv2d_nchw_1[13] = (conv2d_nchw_1[13] + (pad_temp.shared_1[(threadIdx.x + 96)]*placeholder.shared_1[1]))\n",
            "      }\n",
            "    }\n",
            "    compute_1: Buffer(compute, float32, [501760], [])[((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x)] = max(conv2d_nchw_1[0], 0f32)\n",
            "    compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 16)] = max(conv2d_nchw_1[2], 0f32)\n",
            "    compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 32)] = max(conv2d_nchw_1[4], 0f32)\n",
            "    compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 48)] = max(conv2d_nchw_1[6], 0f32)\n",
            "    compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 64)] = max(conv2d_nchw_1[8], 0f32)\n",
            "    compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 80)] = max(conv2d_nchw_1[10], 0f32)\n",
            "    compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 96)] = max(conv2d_nchw_1[12], 0f32)\n",
            "    compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50176)] = max(conv2d_nchw_1[1], 0f32)\n",
            "    compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50192)] = max(conv2d_nchw_1[3], 0f32)\n",
            "    compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50208)] = max(conv2d_nchw_1[5], 0f32)\n",
            "    compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50224)] = max(conv2d_nchw_1[7], 0f32)\n",
            "    compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50240)] = max(conv2d_nchw_1[9], 0f32)\n",
            "    compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50256)] = max(conv2d_nchw_1[11], 0f32)\n",
            "    compute_1[(((((blockIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50272)] = max(conv2d_nchw_1[13], 0f32)\n",
            "  }\n",
            "}\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
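        "# Input in NCHW layout: batch 1, 3 channels, 224x224 image\n",
        "data = te.placeholder((1, 3, 224, 224))\n",
        "# Kernel in OIHW layout: 10 output channels, 3 input channels, 5x5 window\n",
        "kernel = te.placeholder((10, 3, 5, 5))\n",
        "\n",
        "with tvm.target.Target(\"cuda\"):\n",
        "    # stride=1, padding=2, dilation=1, so the output shape is (1, 10, 224, 224)\n",
        "    conv = topi.cuda.conv2d_nchw(data, kernel, 1, 2, 1)\n",
        "    # Apply ReLU to the convolution output; the schedule below fuses it into the conv kernel\n",
        "    out = topi.nn.relu(conv)\n",
        "    # CUDA-specific conv2d schedule, applied to the final (ReLU) output\n",
        "    sconv = topi.cuda.schedule_conv2d_nchw([out])\n",
        "    print(tvm.lower(sconv, [data, kernel], simple_mode=True))"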
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
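        "## Summary\n",
        "\n",
        "In this tutorial, we have seen\n",
        "\n",
        "- How to use the TOPI API for common operations with numpy-style operators.\n",
        "- How TOPI provides generic schedules and operator fusion within a target context (here `cuda`) to generate optimized kernel code."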
      ]
    }
  ],
  "metadata": {
    "interpreter": {
      "hash": "f0a0fcc4cb7375f8ee907b3c51d5b9d65107fda1aab037a85df7b0c09b870b98"
    },
    "kernelspec": {
      "display_name": "Python 3.10.4 (conda)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
